{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LSTM 做词性预测\n",
    "前面我们讲了词嵌入以及 n-gram 模型做单词预测，但是目前还没有用到 RNN，在最后这一次课中，我们会结合前面讲的所有预备知识，教大家如何使用 LSTM 来做词性预测。\n",
    "\n",
    "## 模型介绍\n",
    "对于一个单词，会有这不同的词性，首先能够根据一个单词的后缀来初步判断，比如 -ly 这种后缀，很大概率是一个副词，除此之外，一个相同的单词可以表示两种不同的词性，比如 book 既可以表示名词，也可以表示动词，所以到底这个词是什么词性需要结合前后文来具体判断。\n",
    "\n",
    "根据这个问题，我们可以使用 lstm 模型来进行预测，首先对于一个单词，可以将其看作一个序列，比如 apple 是由 a p p l e 这 5 个单词构成，这就形成了 5 的序列，我们可以对这些字符构建词嵌入，然后输入 lstm，就像 lstm 做图像分类一样，只取最后一个输出作为预测结果，整个单词的字符串能够形成一种记忆的特性，帮助我们更好的预测词性。\n",
    "\n",
    "![](https://ws3.sinaimg.cn/large/006tKfTcgy1fmxi67w0f7j30ap05qq2u.jpg)\n",
    "\n",
    "接着我们把这个单词和其前面几个单词构成序列，可以对这些单词构建新的词嵌入，最后输出结果是单词的词性，也就是根据前面几个词的信息对这个词的词性进行分类。\n",
    "\n",
    "下面我们用例子来简单的说明"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们使用下面简单的训练集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "training_data = [(\"The dog ate the apple\".split(),\n",
    "                  [\"DET\", \"NN\", \"V\", \"DET\", \"NN\"]),\n",
    "                 (\"Everybody read that book\".split(), \n",
    "                  [\"NN\", \"V\", \"DET\", \"NN\"])]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来我们需要对单词和标签进行编码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "word_to_idx = {}\n",
    "tag_to_idx = {}\n",
    "for context, tag in training_data:\n",
    "    for word in context:\n",
    "        if word.lower() not in word_to_idx:\n",
    "            word_to_idx[word.lower()] = len(word_to_idx)\n",
    "    for label in tag:\n",
    "        if label.lower() not in tag_to_idx:\n",
    "            tag_to_idx[label.lower()] = len(tag_to_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'apple': 3,\n",
       " 'ate': 2,\n",
       " 'book': 7,\n",
       " 'dog': 1,\n",
       " 'everybody': 4,\n",
       " 'read': 5,\n",
       " 'that': 6,\n",
       " 'the': 0}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "word_to_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'det': 0, 'nn': 1, 'v': 2}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tag_to_idx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后我们对字母进行编码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "alphabet = 'abcdefghijklmnopqrstuvwxyz'\n",
    "char_to_idx = {}\n",
    "for i in range(len(alphabet)):\n",
    "    char_to_idx[alphabet[i]] = i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'a': 0,\n",
       " 'b': 1,\n",
       " 'c': 2,\n",
       " 'd': 3,\n",
       " 'e': 4,\n",
       " 'f': 5,\n",
       " 'g': 6,\n",
       " 'h': 7,\n",
       " 'i': 8,\n",
       " 'j': 9,\n",
       " 'k': 10,\n",
       " 'l': 11,\n",
       " 'm': 12,\n",
       " 'n': 13,\n",
       " 'o': 14,\n",
       " 'p': 15,\n",
       " 'q': 16,\n",
       " 'r': 17,\n",
       " 's': 18,\n",
       " 't': 19,\n",
       " 'u': 20,\n",
       " 'v': 21,\n",
       " 'w': 22,\n",
       " 'x': 23,\n",
       " 'y': 24,\n",
       " 'z': 25}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "char_to_idx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接着我们可以构建训练数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def make_sequence(x, dic): # 字符编码\n",
    "    idx = [dic[i.lower()] for i in x]\n",
    "    idx = torch.LongTensor(idx)\n",
    "    return idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "  0\n",
       " 15\n",
       " 15\n",
       " 11\n",
       "  4\n",
       "[torch.LongTensor of size 5]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_sequence('apple', char_to_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Everybody', 'read', 'that', 'book']"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_data[1][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       " 4\n",
       " 5\n",
       " 6\n",
       " 7\n",
       "[torch.LongTensor of size 4]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_sequence(training_data[1][0], word_to_idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构建单个字符的 lstm 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class char_lstm(nn.Module):\n",
    "    def __init__(self, n_char, char_dim, char_hidden):\n",
    "        super(char_lstm, self).__init__()\n",
    "        \n",
    "        self.char_embed = nn.Embedding(n_char, char_dim)\n",
    "        self.lstm = nn.LSTM(char_dim, char_hidden)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.char_embed(x)\n",
    "        out, _ = self.lstm(x)\n",
    "        return out[-1] # (batch, hidden)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构建词性分类的 lstm 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class lstm_tagger(nn.Module):\n",
    "    def __init__(self, n_word, n_char, char_dim, word_dim, \n",
    "                 char_hidden, word_hidden, n_tag):\n",
    "        super(lstm_tagger, self).__init__()\n",
    "        self.word_embed = nn.Embedding(n_word, word_dim)\n",
    "        self.char_lstm = char_lstm(n_char, char_dim, char_hidden)\n",
    "        self.word_lstm = nn.LSTM(word_dim + char_hidden, word_hidden)\n",
    "        self.classify = nn.Linear(word_hidden, n_tag)\n",
    "        \n",
    "    def forward(self, x, word):\n",
    "        char = []\n",
    "        for w in word: # 对于每个单词做字符的 lstm\n",
    "            char_list = make_sequence(w, char_to_idx)\n",
    "            char_list = char_list.unsqueeze(1) # (seq, batch, feature) 满足 lstm 输入条件\n",
    "            char_infor = self.char_lstm(Variable(char_list)) # (batch, char_hidden)\n",
    "            char.append(char_infor)\n",
    "        char = torch.stack(char, dim=0) # (seq, batch, feature)\n",
    "        \n",
    "        x = self.word_embed(x) # (batch, seq, word_dim)\n",
    "        x = x.permute(1, 0, 2) # 改变顺序\n",
    "        x = torch.cat((x, char), dim=2) # 沿着特征通道将每个词的词嵌入和字符 lstm 输出的结果拼接在一起\n",
    "        x, _ = self.word_lstm(x)\n",
    "        \n",
    "        s, b, h = x.shape\n",
    "        x = x.view(-1, h) # 重新 reshape 进行分类线性层\n",
    "        out = self.classify(x)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net = lstm_tagger(len(word_to_idx), len(char_to_idx), 10, 100, 50, 128, len(tag_to_idx))\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 50, Loss: 0.86690\n",
      "Epoch: 100, Loss: 0.65471\n",
      "Epoch: 150, Loss: 0.45582\n",
      "Epoch: 200, Loss: 0.30351\n",
      "Epoch: 250, Loss: 0.20446\n",
      "Epoch: 300, Loss: 0.14376\n"
     ]
    }
   ],
   "source": [
    "# 开始训练\n",
    "for e in range(300):\n",
    "    train_loss = 0\n",
    "    for word, tag in training_data:\n",
    "        word_list = make_sequence(word, word_to_idx).unsqueeze(0) # 添加第一维 batch\n",
    "        tag = make_sequence(tag, tag_to_idx)\n",
    "        word_list = Variable(word_list)\n",
    "        tag = Variable(tag)\n",
    "        # 前向传播\n",
    "        out = net(word_list, word)\n",
    "        loss = criterion(out, tag)\n",
    "        train_loss += loss.data[0]\n",
    "        # 反向传播\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    if (e + 1) % 50 == 0:\n",
    "        print('Epoch: {}, Loss: {:.5f}'.format(e + 1, train_loss / len(training_data)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最后我们可以看看预测的结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net = net.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_sent = 'Everybody ate the apple'\n",
    "test = make_sequence(test_sent.split(), word_to_idx).unsqueeze(0)\n",
    "out = net(Variable(test), test_sent.split())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Variable containing:\n",
      "-1.2148  1.9048 -0.6570\n",
      "-0.9272 -0.4441  1.4009\n",
      " 1.6425 -0.7751 -1.1553\n",
      "-0.6121  1.6036 -1.1280\n",
      "[torch.FloatTensor of size 4x3]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'det': 0, 'nn': 1, 'v': 2}\n"
     ]
    }
   ],
   "source": [
    "print(tag_to_idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最后可以得到上面的结果，因为最后一层的线性层没有使用 softmax，所以数值不太像一个概率，但是每一行数值最大的就表示属于该类，可以看到第一个单词 'Everybody' 属于 nn，第二个单词 'ate' 属于 v，第三个单词 'the' 属于det，第四个单词 'apple' 属于 nn，所以得到的这个预测结果是正确的"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
